/*
 * SPDX-FileCopyrightText: 2016 Cesanta Software Limited
 *
 * SPDX-License-Identifier: GPL-2.0-or-later
 *
 * SPDX-FileContributor: 2016-2022 Espressif Systems (Shanghai) CO LTD
 */

#include "soc_support.h"
#include "stub_write_flash.h"
#include "stub_flasher.h"
#include "rom_functions.h"
#include "miniz.h"

/* Local flashing state, shared by the begin/data/end handlers below.

   This is wrapped in a structure because gcc 4.8
   generates significantly more code for ESP32
   if they are static variables (literal pool, I think!)
*/
static struct {
  /* set by handle_flash_begin; cleared by handle_flash_end, but only once
     all expected data has been received (an early flash_end leaves it set) */
  bool in_flash_mode;
  /* flash offset of the next SPI write */
  uint32_t next_write;
  /* sector number for the next erase (offset / FLASH_SECTOR_SIZE) */
  int next_erase_sector;
  /* number of (uncompressed) output bytes remaining to write */
  uint32_t remaining;
  /* number of sectors remaining to erase; erasing runs ahead of writing */
  int remaining_erase_sector;
  /* last error generated by a data packet; reported by handle_flash_end */
  esp_command_error last_error;

  /* inflator state for deflate write; reset by tinfl_init() in
     handle_flash_deflated_begin */
  tinfl_decompressor inflator;
  /* number of compressed bytes remaining to read */
  uint32_t remaining_compressed;
} fs;

/* SPI flash status register bits */
/* Write In Progress: set while an erase/program is still running */
static const uint32_t STATUS_WIP_BIT = (1 << 0);
#if ESP32_OR_LATER
/* Quad Enable; the only bit SPIUnlock() preserves when rewriting status */
static const uint32_t STATUS_QIE_BIT = (1 << 9);  /* Quad Enable */
#endif

/* Report whether a flash write session is active: true after
   handle_flash_begin() until handle_flash_end() succeeds in closing it. */
bool is_in_flash_mode(void)
{
  const bool active = fs.in_flash_mode;
  return active;
}

/* Return the most recent error recorded while processing flash data
   packets (ESP_OK if none since the last handle_flash_begin()). */
esp_command_error get_flash_error(void)
{
  const esp_command_error recorded = fs.last_error;
  return recorded;
}

/* Busy-wait until the SPI host state machine is idle, i.e. no command is
   in progress in the internal host. On ESP32 and later, the SPI0 state
   machine is polled as well. This does not check the flash chip's own
   WIP bit; see spiflash_is_ready() for that.
*/
inline static void spi_wait_ready(void)
{
  while ((READ_REG(SPI_EXT2_REG) & SPI_ST) != 0) {
    /* spin */
  }
#if ESP32_OR_LATER
  while ((READ_REG(SPI0_EXT2_REG) & SPI_ST) != 0) {
    /* spin */
  }
#endif
}

/* Poll the flash chip once: return true when its status register reports
   no write/erase in progress (WIP bit clear), so it can accept the next
   write operation.

   Only blocks while waiting for the SPI host to finish any previous
   command and for the RDSR command issued here to complete.
*/
static bool spiflash_is_ready(void)
{
  spi_wait_ready();

  /* Zero the read-status register, then issue Read Status (RDSR). */
  WRITE_REG(SPI_RD_STATUS_REG, 0);
  WRITE_REG(SPI_CMD_REG, SPI_FLASH_RDSR);
  /* Wait for SPI_CMD_REG to read back 0 before using the result. */
  while (READ_REG(SPI_CMD_REG) != 0) {
    /* spin */
  }

  return !(READ_REG(SPI_RD_STATUS_REG) & STATUS_WIP_BIT);
}

/* Wait for the flash chip to become idle, then send Write Enable (WREN)
   and wait for the SPI host to finish issuing it. */
static void spi_write_enable(void)
{
  do {
    /* poll until the previous write/erase has finished */
  } while (!spiflash_is_ready());

  WRITE_REG(SPI_CMD_REG, SPI_FLASH_WREN);
  do {
    /* wait for SPI_CMD_REG to read back 0 */
  } while (READ_REG(SPI_CMD_REG) != 0);
}

#if ESP32_OR_LATER
/* Pointer to the ROM's flash chip descriptor (at ROM_SPIFLASH_LEGACY),
   passed to the ROM status-register helpers below. */
static esp_rom_spiflash_chip_t *flashchip = (esp_rom_spiflash_chip_t *)ROM_SPIFLASH_LEGACY;

/* Stub version of SPIUnlock() that replaces the version in ROM.

   This works around a bug where the ROM SPIUnlock sometimes reads the
   wrong high status byte (RDSR2 result) and then copies it back to the
   flash status register. That can cause the CMP lock bit or the Status
   Register Protect bit to become set.

   Returns SPI_FLASH_RESULT_OK on success, or SPI_FLASH_RESULT_ERR if
   reading or writing the status register fails.
 */
SpiFlashOpResult SPIUnlock(void)
{
  uint32_t status;

  /* ROM SPI_read_status_high() doesn't wait for the SPI host to be idle */
  spi_wait_ready();

#if ESP32S2_OR_LATER
  const bool read_failed = SPI_read_status_high(flashchip, &status) != SPI_FLASH_RESULT_OK;
#else
  const bool read_failed = SPI_read_status_high(&status) != SPI_FLASH_RESULT_OK;
#endif // ESP32S2_OR_LATER
  if (read_failed) {
    return SPI_FLASH_RESULT_ERR;
  }

  /* Clear all bits except QIE, if it is set.
     (This is different from ROM SPIUnlock, which keeps all bits as-is.)
   */
  status &= STATUS_QIE_BIT;

  spi_write_enable();

  /* Select a two-byte Write Status Register transfer. */
  REG_SET_MASK(SPI_CTRL_REG, SPI_WRSR_2B);
  return (SPI_write_status(flashchip, status) == SPI_FLASH_RESULT_OK)
         ? SPI_FLASH_RESULT_OK
         : SPI_FLASH_RESULT_ERR;
}
#endif // ESP32_OR_LATER

#if defined(ESP32S3) && !defined(ESP32S3BETA2)
/* Program byte_length bytes from addr_source into flash at spi_addr with
   CMD_PROGRAM_PAGE_4B (32-bit address phase) through the ROM opiflash
   helpers. Data is sent in sub-writes of at most 32 bytes. Each sub-write
   gets a write-enable before it and a wait-for-idle after it.

   This function does not split at page boundaries; SPIWrite4B() keeps
   each call within one page. Always returns ESP_ROM_SPIFLASH_RESULT_OK.
*/
static esp_rom_spiflash_result_t page_program_internal(int spi_num, uint32_t spi_addr, uint8_t* addr_source, uint32_t byte_length)
{
    enum { WRITE_SUB_LEN = 32 };  /* max bytes per program command */
    /* Phase lengths appear to be in bits (data phase is 8 * bytes). */
    const uint16_t cmd = CMD_PROGRAM_PAGE_4B;
    const uint8_t cmd_len = 8;
    const int dummy = 0;

    esp_rom_opiflash_wait_idle();

    uint32_t addr = spi_addr;
    uint8_t *src = addr_source;
    for (int32_t left = byte_length; left > 0; ) {
        const uint32_t chunk = (left >= WRITE_SUB_LEN) ? WRITE_SUB_LEN : left;

        esp_rom_opiflash_wren();
        esp_rom_opiflash_exec_cmd(spi_num, SPI_FLASH_FASTRD_MODE,
                                  cmd, cmd_len,
                                  addr, 32,
                                  dummy,
                                  src, 8 * chunk,
                                  NULL, 0,
                                  ESP_ROM_OPIFLASH_SEL_CS0,
                                  true);
        esp_rom_opiflash_wait_idle();

        src += chunk;
        addr += chunk;
        left -= chunk;
    }
    return ESP_ROM_SPIFLASH_RESULT_OK;
}
#endif // ESP32S3

#if defined(ESP32S3) && !defined(ESP32S3BETA2)
/* Write len bytes from src_addr to flash at target using 4-byte
   addressing (via page_program_internal()). The data is split so that no
   single program call crosses a 256-byte page boundary.

   A non-positive len writes nothing. Returns ESP_ROM_SPIFLASH_RESULT_OK,
   or the first non-OK result from page_program_internal().
*/
static esp_rom_spiflash_result_t SPIWrite4B(int spi_num, uint32_t target, uint8_t *src_addr, int32_t len)
{
    const uint32_t page_size = 256;
    esp_rom_spiflash_result_t rc = ESP_ROM_SPIFLASH_RESULT_OK;

    esp_rom_opiflash_wait_idle();

    /* Clamp first so a negative len can't be promoted to a huge unsigned
       count in the arithmetic below. The loop counts bytes with a 32-bit
       value, so writes spanning any number of pages terminate. */
    uint32_t remaining = (len > 0) ? (uint32_t)len : 0;
    while (remaining > 0) {
        /* Bytes left in the current page, capped by what's left to write. */
        uint32_t chunk = page_size - (target % page_size);
        if (chunk > remaining) {
            chunk = remaining;
        }

        rc = page_program_internal(spi_num, target, src_addr, chunk);
        if (rc != ESP_ROM_SPIFLASH_RESULT_OK) {
            break;
        }

        target += chunk;
        src_addr += chunk;
        remaining -= chunk;
    }

    esp_rom_opiflash_wait_idle();
    return rc;
}
#endif // defined(ESP32S3) && !defined(ESP32S3BETA2)

/* Start a new flash write session of total_size bytes at offset.

   Resets the write cursor, the remaining byte count and the error state.
   It also records which sectors still need erasing: every sector touched
   by [offset, offset + total_size). Unless the ESP32-S3 large-flash path
   is active, the flash is then unlocked with SPIUnlock().

   Returns ESP_OK, or ESP_FAILED_SPI_UNLOCK if SPIUnlock() fails. The
   session state above is set up (and in_flash_mode set) in either case.
*/
esp_command_error handle_flash_begin(uint32_t total_size, uint32_t offset) {
  /* bytes between the start of offset's sector and offset itself */
  const uint32_t lead_in = offset % FLASH_SECTOR_SIZE;

  fs.in_flash_mode = true;
  fs.next_write = offset;
  fs.remaining = total_size;
  fs.last_error = ESP_OK;
  fs.next_erase_sector = offset / FLASH_SECTOR_SIZE;
  /* Round the span from the first sector's start up to whole sectors. */
  fs.remaining_erase_sector = (lead_in + total_size + FLASH_SECTOR_SIZE - 1) / FLASH_SECTOR_SIZE;

#if defined(ESP32S3) && !defined(ESP32S3BETA2)
  if (large_flash_mode) {
    /* Large-flash mode skips SPIUnlock and only waits for opiflash idle. */
    esp_rom_opiflash_wait_idle();
    return ESP_OK;
  }
#endif //defined(ESP32S3) and !defined(ESP32S3BETA2)

  if (SPIUnlock() != 0) {
    return ESP_FAILED_SPI_UNLOCK;
  }
  return ESP_OK;
}

/* Start a compressed (zlib) flash write session.

   uncompressed_size bytes of decompressed output will be written at
   offset, and compressed_size bytes of compressed input are expected.
   The inflator is reset and the compressed byte count is recorded even
   if handle_flash_begin() fails; its result is what gets returned.
*/
esp_command_error handle_flash_deflated_begin(uint32_t uncompressed_size, uint32_t compressed_size, uint32_t offset) {
  const esp_command_error begin_result = handle_flash_begin(uncompressed_size, offset);

  fs.remaining_compressed = compressed_size;
  tinfl_init(&fs.inflator);

  return begin_result;
}

/* Issue a sector (4KB) or block (64KB) erase with a 3-byte address
   through the SPI host command registers. Waits for SPI_CMD_REG to read
   back 0. Only the low 24 bits of addr are sent.
 */
static void spi_erase_24bit(uint32_t addr, bool block_erase)
{
  /* block erase, 64KB : sector erase, 4KB */
  const uint32_t command = block_erase ? SPI_FLASH_BE : SPI_FLASH_SE;

  WRITE_REG(SPI_ADDR_REG, addr & 0xffffff);
  WRITE_REG(SPI_CMD_REG, command);
  while (READ_REG(SPI_CMD_REG) != 0) { }
}

#if defined(ESP32S3) && !defined(ESP32S3BETA2)
/* Send erase command cmd with a 24-bit address through the ROM opiflash
   helpers: wait for idle, write-enable, execute, wait for idle. */
static void opiflash_erase_24bit(uint16_t cmd, uint32_t addr)
{
  esp_rom_opiflash_wait_idle();
  esp_rom_opiflash_wren();

  esp_rom_opiflash_exec_cmd(1, SPI_FLASH_SLOWRD_MODE,
                            cmd, 8,
                            addr, 24,
                            0,
                            NULL, 0,
                            NULL, 0,
                            1,
                            true);
  esp_rom_opiflash_wait_idle();
}
#endif // defined(ESP32S3) && !defined(ESP32S3BETA2)

/* Erase the next sector or block (depending if we're at a block boundary).

   Updates fs.next_erase_sector & fs.remaining_erase_sector on success.

   If nothing left to erase, returns immediately.

   Returns immediately if SPI flash not yet ready for a write operation.

   Does not wait for the erase to complete - the next SPI operation
   should check if a write operation is currently in progress.
 */
static void start_next_erase(void)
{
  if (fs.remaining_erase_sector == 0) {
    return; /* nothing left to erase */
  }
  if (!spiflash_is_ready()) {
    return; /* don't wait for flash to be ready, caller will call again if needed */
  }

  /* Perform a 64KB block erase when block-aligned and a whole block remains. */
  const bool block_erase = fs.remaining_erase_sector >= SECTORS_PER_BLOCK
                           && fs.next_erase_sector % SECTORS_PER_BLOCK == 0;
  const uint32_t addr = fs.next_erase_sector * FLASH_SECTOR_SIZE;

  spi_write_enable();
  spi_wait_ready();

#if defined(ESP32S3) && !defined(ESP32S3BETA2)
  if (large_flash_mode) {
    if (addr < (1 << 24)) {
      /* Below 16MB a 3-byte address is sufficient. */
      opiflash_erase_24bit(block_erase ? CMD_LARGE_BLOCK_ERASE : CMD_SECTOR_ERASE, addr);
    } else if (block_erase) {
      esp_rom_opiflash_erase_block_64k(fs.next_erase_sector / SECTORS_PER_BLOCK);
    } else {
      esp_rom_opiflash_erase_sector(fs.next_erase_sector);
    }
  } else {
    spi_erase_24bit(addr, block_erase);
  }
#else
  spi_erase_24bit(addr, block_erase);
#endif // defined(ESP32S3) && !defined(ESP32S3BETA2)

  const uint32_t sectors_erased = block_erase ? SECTORS_PER_BLOCK : 1;
  fs.remaining_erase_sector -= sectors_erased;
  fs.next_erase_sector += sectors_erased;
}

/* Write data to flash (either direct for non-compressed upload, or
   freshly decompressed), erasing ahead as it goes.

   At most fs.remaining bytes are written; any excess in data_buf is
   treated as padding and ignored. On a failed SPI write, fs.last_error is
   set to ESP_FAILED_SPI_OP, but the cursor still advances.

   Updates fs.next_erase_sector, fs.remaining_erase_sector, fs.next_write
   and fs.remaining.
*/
void handle_flash_data(void *data_buf, uint32_t length) {
  /* Trim the final block, as it may have padding beyond
     the length we are writing */
  const uint32_t to_write = (length > fs.remaining) ? fs.remaining : length;
  if (to_write == 0) {
    return;
  }

  /* Keep erasing until the erase cursor has passed the sector containing
     next_write + to_write. That is the sector just past the end of this
     write when it ends on a sector boundary; erasing it early is harmless
     because it is still inside the range queued for erase. */
  const int end_sector = (fs.next_write + to_write) / FLASH_SECTOR_SIZE;
  while (fs.remaining_erase_sector > 0 && fs.next_erase_sector <= end_sector) {
    start_next_erase();
  }
  while (!spiflash_is_ready()) {
    /* wait for any in-flight erase to finish */
  }

  /* do the actual write */
  uint8_t res;
#if defined(ESP32S3) && !defined(ESP32S3BETA2)
  if (large_flash_mode) {
    res = SPIWrite4B(1, fs.next_write, data_buf, to_write);
  } else {
    res = SPIWrite(fs.next_write, data_buf, to_write);
  }
#else
  res = SPIWrite(fs.next_write, data_buf, to_write);
#endif // defined(ESP32S3) && !defined(ESP32S3BETA2)

  if (res != 0) {
    fs.last_error = ESP_FAILED_SPI_OP;
  }
  fs.next_write += to_write;
  fs.remaining -= to_write;
}

#if !ESP8266
/* Write encrypted data to flash (either direct for non-compressed upload,
   or freshly decompressed), erasing ahead as it goes.

   At most fs.remaining bytes are written; any excess in data_buf is
   treated as padding. On a failed write, fs.last_error is set to
   ESP_FAILED_SPI_OP, but the cursor still advances.

   On ESP32S2_OR_LATER, encrypted-write mode is enabled on entry. It is
   disabled again on every exit path, including the nothing-to-write
   case. Previously an early return left that mode enabled.

   Updates fs.next_erase_sector, fs.remaining_erase_sector, fs.next_write
   and fs.remaining.
*/
void handle_flash_encrypt_data(void *data_buf, uint32_t length) {
#if ESP32S2_OR_LATER
  SPI_Write_Encrypt_Enable();
#endif

  if (length > fs.remaining) {
      /* Trim the final block, as it may have padding beyond
         the length we are writing */
      length = fs.remaining;
  }

  if (length > 0) {
    /* what sector is this write going to end in?
       make sure we've erased at least that far.
    */
    const int last_sector = (fs.next_write + length) / FLASH_SECTOR_SIZE;
    while (fs.remaining_erase_sector > 0 && fs.next_erase_sector <= last_sector) {
      start_next_erase();
    }
    while (!spiflash_is_ready()) {
      /* wait for any in-flight erase to finish */
    }

    /* do the actual write */
    int res;
#if ESP32
    res = esp_rom_spiflash_write_encrypted(fs.next_write, data_buf, length);
#else
    res = SPI_Encrypt_Write(fs.next_write, data_buf, length);
#endif

    if (res) {
      fs.last_error = ESP_FAILED_SPI_OP;
    }
    fs.next_write += length;
    fs.remaining -= length;
  }

#if ESP32S2_OR_LATER
  /* Always reached, so encrypted-write mode never leaks past this call. */
  SPI_Write_Encrypt_Disable();
#endif
}

#endif // !ESP8266

/* Decompress one block of a zlib stream and write the output to flash.

   Decompressed bytes are staged in a static 32KB buffer. It is flushed
   via handle_flash_data() when full or when the stream ends. An SPI erase
   is started opportunistically before each decompression step. Errors
   are recorded in fs.last_error and reported with a later packet:
     ESP_TOO_MUCH_DATA   - data arrived after all output was written
                           (up to 4 bytes are tolerated for the Adler-32
                           trailer),
     ESP_INFLATE_ERROR   - tinfl reported a decompression failure,
     ESP_NOT_ENOUGH_DATA - the stream ended before fs.remaining reached 0.
*/
void handle_flash_deflated_data(void *data_buf, uint32_t length) {
  /* Persistent across calls: a stream spans many data blocks. out_buf is
     also passed to tinfl as its wrapping output/dictionary buffer (no
     TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF), hence 32KB. */
  static uint8_t out_buf[32768];
  static uint8_t *next_out = out_buf;

  /* tinfl_init() in handle_flash_deflated_begin() sets m_state to 0, and
     tinfl never returns to state 0 mid-stream. On the first block of a new
     stream, drop any output left behind by an aborted or failed previous
     stream so it is not written at the new image's offset. */
  if (fs.inflator.m_state == 0) {
    next_out = out_buf;
  }

  /* if all data has been uploaded and another block comes,
     accept it only if it is part of a 4-byte Adler-32 checksum */
  if (fs.remaining == 0 && length > 4) {
    fs.last_error = ESP_TOO_MUCH_DATA;
    return;
  }

  /* Typed cursor: arithmetic on void * is a GNU extension, not ISO C. */
  const uint8_t *in = data_buf;
  int status = TINFL_STATUS_NEEDS_MORE_INPUT;

  while (length > 0 && fs.remaining > 0 && status > TINFL_STATUS_DONE) {
    size_t in_bytes = length; /* input remaining */
    size_t out_bytes = out_buf + sizeof(out_buf) - next_out; /* output space remaining */
    int flags = TINFL_FLAG_PARSE_ZLIB_HEADER;
    if (fs.remaining_compressed > length) {
      flags |= TINFL_FLAG_HAS_MORE_INPUT;
    }

    /* start an opportunistic erase: decompressing takes time, so might as
       well be running a SPI erase in the background. */
    start_next_erase();

    /* in_bytes/out_bytes are updated to the amounts actually consumed/produced. */
    status = tinfl_decompress(&fs.inflator, in, &in_bytes,
                              out_buf, next_out, &out_bytes,
                              flags);

    fs.remaining_compressed -= in_bytes;
    length -= in_bytes;
    in += in_bytes;

    next_out += out_bytes;
    size_t bytes_in_out_buf = next_out - out_buf;
    if (status == TINFL_STATUS_DONE || bytes_in_out_buf == sizeof(out_buf)) {
      /* Output buffer full, or done */
      handle_flash_data(out_buf, bytes_in_out_buf);
      next_out = out_buf;
    }
  } /* while */

  if (status < TINFL_STATUS_DONE) {
    /* error won't get sent back to esptool.py until next block is sent */
    fs.last_error = ESP_INFLATE_ERROR;
  }

  if (status == TINFL_STATUS_DONE && fs.remaining > 0) {
    fs.last_error = ESP_NOT_ENOUGH_DATA;
  }
}

/* Finish the current flash write session.

   Returns ESP_NOT_IN_FLASH_MODE if no session is active, or
   ESP_NOT_ENOUGH_DATA if bytes are still outstanding; the session stays
   open in both cases. Otherwise it closes the session and returns the
   last error recorded while processing data packets (ESP_OK if none).
*/
esp_command_error handle_flash_end(void)
{
  esp_command_error result;

  if (!fs.in_flash_mode) {
    result = ESP_NOT_IN_FLASH_MODE;
  } else if (fs.remaining > 0) {
    result = ESP_NOT_ENOUGH_DATA;
  } else {
    fs.in_flash_mode = false;
    result = fs.last_error;
  }

  return result;
}
